{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example 03 - Official WidebandSig53 Dataset\n",
    "This notebook walks through how to use `torchsig` to generate the Official WidebandSig53 Dataset.\n",
    "\n",
    "-------------------------------------------"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.utils.visualize import MaskClassVisualizer, mask_class_to_outline, complex_spectrogram_to_magnitude\n",
    "from torchsig.transforms.target_transforms import DescToMaskClass, DescToListTuple\n",
    "from torchsig.transforms import Spectrogram, Normalize\n",
    "from torchsig.datasets.wideband_sig53 import WidebandSig53\n",
    "from torchsig.transforms.transforms import Compose\n",
    "from torchsig.datasets import conf\n",
    "from torchsig.datasets.datamodules import WidebandSig53DataModule\n",
    "from torchsig.datasets.signal_classes import sig53\n",
    "\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "-----------------------------\n",
    "## Generate the Wideband Sig53 Dataset\n",
    "To generate the WidebandSig53 dataset, several parameters are given to the imported `WidebandSig53DataModule` class. These paramters are:\n",
    "- `root` ~ A string to specify the root directory of where to generate and/or read an existing WidebandSig53 dataset\n",
    "- `train` ~ A boolean to specify if the WidebandSig53 dataset should be the training (True) or validation (False) sets\n",
    "- `qa` - A boolean to specify whether to generate a small subset of Sig53 (True), or the full dataset (False), default is True\n",
    "- `impaired` ~ A boolean to specify if the WidebandSig53 dataset should be the clean version or the impaired version\n",
    "- `transform` ~ Optionally, pass in any data transforms here if the dataset will be used in an ML training pipeline. Note: these transforms are not called during the dataset generation. The static saved dataset will always be in IQ format. The transform is only called when retrieving data examples.\n",
    "- `target_transform` ~ Optionally, pass in any target transforms here if the dataset will be used in an ML training pipeline. Note: these target transforms are not called during the dataset generation. The static saved dataset will always be saved as tuples in the LMDB dataset. The target transform is only called when retrieving data examples.\n",
    "\n",
    "A combination of the `train` and the `impaired` booleans determines which of the four (4) distinct WidebandSig53 datasets will be instantiated:\n",
    "| `impaired` | `qa` | Result |\n",
    "| ---------- | ---- | ------- |\n",
    "| `False` | `False` | Clean datasets of train=250k examples and val=25k examples |\n",
    "| `False` | `True` | Clean datasets of train=250 examples and val=250 examples |\n",
    "| `True` | `False` | Impaired datasets of train=250k examples and val=25k examples |\n",
    "| `True` | `True` | Impaired datasets of train=250 examples and val=250 examples |\n",
    "\n",
    "The final option of the impaired validation set is the dataset to be used when reporting any results with the official WidebandSig53 dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate WidebandSig53 DataModule\n",
    "from torchsig.transforms import *\n",
    "root = \"../datasets/wideband_sig53\"\n",
    "impaired = True\n",
    "qa = True\n",
    "fft_size = 512\n",
    "num_classes = 53\n",
    "batch_size = 16\n",
    "num_workers = 4\n",
    "\n",
    "transform = Compose([                \n",
    "    Spectrogram(nperseg=fft_size, noverlap=0, nfft=fft_size, mode='complex'),\n",
    "    Normalize(norm=np.inf, flatten=True),\n",
    "])\n",
    "\n",
    "target_transform = Compose([\n",
    "    DescToMaskClass(num_classes=num_classes, width=fft_size, height=fft_size),\n",
    "])\n",
    "\n",
    "datamodule = WidebandSig53DataModule(\n",
    "    root=root,\n",
    "    impaired=impaired,\n",
    "    qa=qa,\n",
    "    fft_size=fft_size,\n",
    "    num_classes=num_classes,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers\n",
    ")\n",
    "\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "\n",
    "wideband_sig53 = datamodule.train\n",
    "\n",
    "# Retrieve a sample and print out information\n",
    "idx = np.random.randint(len(wideband_sig53))\n",
    "data, label = wideband_sig53[idx]\n",
    "print(\"Dataset length: {}\".format(len(wideband_sig53)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "print(\"Label shape: {}\".format(label.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot Subset to Verify\n",
    "The `MaskClassVisualizer` can be passed a `Dataloader` and plot visualizations of the dataset. The `batch_size` of the `DataLoader` determines how many examples to plot for each iteration over the visualizer. Note that the dataset itself can be indexed and plotted sequentially using any familiar python plotting tools as an alternative plotting method to using the `spdata` `Visualizer` as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_loader = datamodule.train_dataloader()\n",
    "\n",
    "visualizer = MaskClassVisualizer(\n",
    "    data_loader=data_loader,\n",
    "    visualize_transform=complex_spectrogram_to_magnitude,\n",
    "    visualize_target_transform=mask_class_to_outline,\n",
    "    class_list=sig53.class_list\n",
    ")\n",
    "\n",
    "for figure in iter(visualizer):\n",
    "    figure.set_size_inches(16, 9)\n",
    "    plt.show()\n",
    "    break"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "----\n",
    "### Analyze Dataset\n",
    "The dataset can also be analyzed at the macro level for details such as the distribution of classes and number of signals per sample. The below analysis reads information directly from the non-target transformed tuple annotations. Since this is different than the above dataset instantiation, the dataset is re-instantiated for analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-instantiate the WidebandSig53 Dataset witbatch_size=1, num_workers=1, hout a target transform and without using the RFData objects\n",
    "datamodule = WidebandSig53DataModule(\n",
    "    root=root,\n",
    "    impaired=impaired,\n",
    "    qa=qa,\n",
    "    fft_size=fft_size,\n",
    "    num_classes=num_classes,\n",
    "    transform=transform,\n",
    "    target_transform=None,\n",
    "    batch_size=1,\n",
    "    num_workers=1\n",
    ")\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "\n",
    "wideband_sig53 = datamodule.train\n",
    "\n",
    "# Loop through the dataset recording classes and SNRs\n",
    "class_counter_dict = {\n",
    "    class_name: 0 for class_name in list(wideband_sig53.class_list)\n",
    "}\n",
    "num_signals_per_sample = []\n",
    "\n",
    "for idx in tqdm(range(len(wideband_sig53))):\n",
    "    data, annotation = wideband_sig53[idx]\n",
    "    num_signals_per_sample.append(len(annotation))\n",
    "    for signal_annotation in annotation:\n",
    "        class_counter_dict[signal_annotation[\"class_name\"]] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of classes\n",
    "class_names = list(class_counter_dict.keys())\n",
    "num_classes = list(class_counter_dict.values())\n",
    "\n",
    "plt.figure(figsize=(9,9))\n",
    "plt.pie(num_classes, labels=class_names)\n",
    "plt.title(\"Class Distribution Pie Chart\")\n",
    "plt.show()\n",
    "\n",
    "plt.figure(figsize=(11,4))\n",
    "plt.bar(class_names, num_classes)\n",
    "plt.xticks(rotation=90)\n",
    "plt.title(\"Class Distribution Bar Chart\")\n",
    "plt.xlabel(\"Modulation Class Name\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "The above distribution of classes shows all OFDM signals appearing less frequently than the remaining modulations. This makes sense because OFDM signals are drawn from a random distribution of bandwidths that are inherently larger than the remaining signals, meaning fewer OFDM signals can fit into a wideband spectrum without overlapping. Additionally, the random bursty probability and durations of OFDM signals makes it less likely to occupy a wideband capture with many short-time bursts, while the remaining modulations experience this behavior at a higher probility."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of number of signals per sample\n",
    "plt.figure(figsize=(11,8))\n",
    "plt.hist(x=num_signals_per_sample, bins=np.arange(1,max(num_signals_per_sample)+1)-0.5)\n",
    "plt.title(\"Distribution of Number of Signals Per Sample\\nTotal Number: {} - Average: {} - Max: {}\".format(\n",
    "    sum(num_signals_per_sample),\n",
    "    np.mean(np.asarray(num_signals_per_sample)),\n",
    "    max(num_signals_per_sample),\n",
    "))\n",
    "plt.xlabel(\"Number of Signal Bins\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "The above distribution of the number of signals per sample shows the most commonly seen sample has two signals present. The average is slightly around 4 signals per sample and the max is 26."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For additional analysis, reinstantiate the dataset without a transform, such that the RFDescriptions can be read\n",
    "datamodule = WidebandSig53DataModule(\n",
    "    root=root,\n",
    "    impaired=impaired,\n",
    "    qa=qa,\n",
    "    fft_size=fft_size,\n",
    "    num_classes=num_classes,\n",
    "    transform=None,\n",
    "    target_transform=None,\n",
    "    batch_size=1,\n",
    "    num_workers=1\n",
    ")\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "\n",
    "wideband_sig53 =datamodule.train\n",
    "\n",
    "num_samples = len(wideband_sig53)\n",
    "snrs = []\n",
    "bandwidths = []\n",
    "durations = []\n",
    "for idx in tqdm(range(num_samples)):\n",
    "    label = wideband_sig53[idx][1]\n",
    "    for meta in label:\n",
    "        snrs.append(meta[\"snr\"])\n",
    "        bandwidths.append(meta[\"bandwidth\"])\n",
    "        durations.append(meta[\"duration\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of SNR values\n",
    "plt.figure(figsize=(11,4))\n",
    "plt.hist(x=snrs, bins=100)\n",
    "plt.title(\"SNR Distribution\")\n",
    "plt.xlabel(\"SNR Bins (dB)\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of bandwidth values\n",
    "plt.figure(figsize=(11,4))\n",
    "plt.hist(x=bandwidths, bins=100)\n",
    "plt.title(\"Bandwidth Distribution\")\n",
    "plt.xlabel(\"BW Bins\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of bandwidth values\n",
    "# plt.figure(figsize=(11,4))\n",
    "# plt.hist(x=durations, bins=100)\n",
    "# plt.title(\"Duration Distribution\")\n",
    "# plt.xlabel(\"Duration Bins\")\n",
    "# plt.ylabel(\"Counts\")\n",
    "# plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
